"""The purpose of this module is to transfer input of a certain task
(outputs stored on other nodes) before processing can begin on them.
This module will be used to pull input for the Reduce function. The
difference between this class and the InputFetcher class is that this
one requires the explicit provision of the IP from which the files
need to be fetched, whereas InputFetcher finds where to fetch the
files from by using broadcast queries.

"""

import traceback
import time
import shutil
import socket
import os
import ma.log
import ma.const
import ma.net.brdcstsv
from . import outputserver
from . import constants
import copy
import errno



class InputFetcherNoQuery(object):
    """Fetch the inputs of a reduce task from an explicit list of hosts.

    Each task-output tuple supplied by the caller carries the IP of the
    node hosting the file (tuple index 3), so — unlike the broadcast-based
    InputFetcher — no broadcast query is performed.  Outputs already
    present on this node are copied locally; remote outputs are pulled
    over TCP from the hosting node's output server.
    """

    # Fetcher flavours: MAP_REDUCE_TYPE embeds the destination reduce
    # task's partitioning code in the fetched filename, MRIMPROV_TYPE
    # (MR+) does not.
    MAP_REDUCE_TYPE = 0
    MRIMPROV_TYPE = 1


    def __init__(self, output_server, nodetracker, task_id, job_id, fetch_type=MRIMPROV_TYPE):
        """Initialize the Input Fetcher.

        output_server -- local output server object (kept for queries)
        nodetracker   -- used to discover which task outputs are local
        task_id, job_id -- identifiers of the consuming task/job; kept in
                           the signature for interface compatibility (the
                           per-call ids given to fetch_input are the ones
                           actually used)
        fetch_type    -- MRIMPROV_TYPE (default) or MAP_REDUCE_TYPE
        """
        # logger is created before the try block so the error handler
        # below can rely on it
        self.__log = ma.log.get_logger("ma.commons")

        try:
            # TCP port on which remote output servers accept file requests
            self.__output_comm_port = ma.const.XmlData.get_int_data(ma.const.xml_net_output_srv_bind_port)

            # file transfer buffer size
            self.__file_transfer_buff_size = ma.const.XmlData.get_int_data(ma.const.xml_net_file_transfer_buff_size)

            # server for querying results of the broadcast
            self.__output_server = output_server

            # for querying which task outputs are local
            self.__nodetracker = nodetracker

            self.__fetch_type = fetch_type

            self.__log.info('Initialized the Input Fetcher class')
        except Exception as err_msg:
            # "Fetcher" fixed (original message said "Fetched")
            self.__log.error("Error while creating Input Fetcher object: %s", err_msg)


    def fetch_input(self, outputs_of_tasks, task_id, job_id):
        """Fetch every task output listed in *outputs_of_tasks*.

        Called every time the inputs for a reduce task are required.
        *task_id* is needed for MR so the hash can be matched and the
        appropriate output served; *job_id* selects the input directory
        the files are copied into.  Returns the subset of
        *outputs_of_tasks* that could NOT be transferred from the
        network (None when an unexpected error aborts the whole fetch).
        """
        try:
            # work on a deep copy so the caller's list is never mutated
            remaining_output_of_tasks = copy.deepcopy(outputs_of_tasks)

            # first satisfy whatever output is already on this node and
            # keep only the tasks that still need to be shuffled
            remaining_output_of_tasks = self.__copy_local_task_outputs(remaining_output_of_tasks, task_id, job_id)

            self.__log.debug('Remaining outputs after local copy %s', str(remaining_output_of_tasks))

            # tuples successfully transferred (removed after the loop)
            transferred_tasks = []

            # iterate through all outputs and transfer each of them
            for remaining_out_task in remaining_output_of_tasks:
                # tuple index 3 holds the hosting IP
                host_ip = remaining_out_task[3]

                self.__log.debug('Fetching output %s hosted at %s', str((remaining_out_task[0], remaining_out_task[1], remaining_out_task[2], task_id)), str(host_ip))

                if host_ip is not None:
                    # transfer the file
                    self.__log.debug('Transferring file for %s', str(remaining_out_task))

                    if self.__transfer_file(remaining_out_task, task_id, job_id, host_ip):
                        transferred_tasks.append(remaining_out_task)
                        self.__log.debug('Transferred file for %s', str(remaining_out_task))
                    else:
                        self.__log.error('Failed transferring file for %s', str(remaining_out_task))

            # remove successes only after iterating: removing while
            # iterating would skip elements
            for done_task in transferred_tasks:
                remaining_output_of_tasks.remove(done_task)

            self.__log.debug('Input fetched for job_id %d, task_id %d with these remaining %s', job_id, task_id, str(remaining_output_of_tasks))

            # the tasks which couldn't be retrieved from the network
            return remaining_output_of_tasks

        except Exception as err_msg:
            # print_exc(): the original print_exception() call had no
            # arguments and itself raised a TypeError
            traceback.print_exc()
            self.__log.error("Error while fetching input for task %d : %s", task_id, err_msg)


    def __copy_local_task_outputs(self, outputs_of_tasks, task_id, job_id):
        """Copy locally-hosted task outputs into the job's input directory.

        Useful when some of the tasks to be shuffled were assigned to
        this node in the first place.  Returns *outputs_of_tasks* with
        the locally-satisfied entries removed (None on unexpected error,
        matching the original behaviour).
        """
        try:
            local_tasks = []

            # input directory where any locally found output is placed
            input_dir_path = ma.const.JobsXmlData.get_filepath_str_data(ma.const.xml_local_input_temp_dir, job_id)

            # scan every pending output and copy the ones hosted here
            for task_out in outputs_of_tasks:
                filepath = self.__nodetracker.getCompleteTaskOutputPath(task_out[0], task_out[1], task_out[2], task_id)

                # a non-None path means the output is available locally
                if filepath is not None:
                    input_filepath = input_dir_path + os.sep + os.path.basename(filepath)
                    local_tasks.append(task_out)
                    # copy from the output directory to the input directory
                    shutil.copyfile(filepath, input_filepath)
                    self.__log.debug('Transferred file locally for input %s to %s', filepath, input_filepath)

            # remove after the scan so the list is not mutated mid-iteration
            for task_out in local_tasks:
                outputs_of_tasks.remove(task_out)

            return outputs_of_tasks

        except Exception as err_msg:
            traceback.print_exc()
            self.__log.error('Error while copying inputs locally : %s', err_msg)


    def __transfer_file(self, task_out_info, task_id, job_id, host_ip):
        """Pull one task output over TCP from *host_ip* into the job's input dir.

        The destination filename depends on the fetcher type: plain
        MapReduce embeds the reduce task's partition key, MR+ does not.
        Network failures are retried until the whole file is received.
        Returns True on success, False when the transfer cannot be set up
        (e.g. a reduce output requested in MapReduce mode, or the
        destination file cannot be created).
        """
        try:
            # input directory where the fetched file must land
            input_dir_path = ma.const.JobsXmlData.get_filepath_str_data(ma.const.xml_local_input_temp_dir, job_id)

            # work out the filepath the transferred data is written to
            if self.__fetch_type == InputFetcherNoQuery.MRIMPROV_TYPE:
                # MR+ fetcher: no partitioning code in the filename
                if constants.MAP == task_out_info[1]:
                    input_write_filepath = input_dir_path + os.sep + ma.const.JobsXmlData.get_str_data(ma.const.xml_map_output_filename, task_out_info[0], task_out_info[2])
                else:
                    input_write_filepath = input_dir_path + os.sep + ma.const.JobsXmlData.get_str_data(ma.const.xml_reduce_output_filename, task_out_info[0], task_out_info[2])
            else:
                # plain MapReduce: use the partitioning code in the
                # filename, and refuse to transfer a reduce output
                if constants.MAP == task_out_info[1]:
                    # TODO: insert partitioner for task id
                    input_write_filepath = input_dir_path + os.sep + ma.const.JobsXmlData.get_str_data(ma.const.xml_map_output_filename_with_key, task_out_info[0], task_out_info[2], task_id)
                else:
                    self.__log.error('Trying to transfer a reduce file of job_id %d and task_id %d', task_out_info[0], task_out_info[2])
                    raise IOError('Trying to transfer a reduce file of job_id %d and task_id %d' % (task_out_info[0], task_out_info[2]))

            is_file_transfer_complete = False

            # loop until successful transfer of the whole file
            while not is_file_transfer_complete:
                sock = None

                # binary mode ('wb', truncating): recv() returns bytes,
                # so the original text-mode 'w+' write would fail.  A
                # failure to create the file is not retryable — let it
                # propagate to the outer handler (abort, return False).
                with open(input_write_filepath, 'wb') as fd:
                    self.__log.debug('Opened file for transfer %s', input_write_filepath)

                    try:
                        out_info_msg = str(task_out_info[0]) + str(task_out_info[1]) + str(task_out_info[2]) + '_' + str(task_id)
                        pckt_msg = outputserver.PCKT_MSG % (outputserver.REQUEST_FILE_TAG, out_info_msg)

                        # establish TCP connection
                        self.__log.debug('Establishing connection with the node hosting the file %s', task_out_info)
                        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                        sock.connect((host_ip, self.__output_comm_port))

                        self.__log.debug('Sending request for file: %s to %s', pckt_msg, str((host_ip, self.__output_comm_port)))

                        # sendall() loops over short sends; the original
                        # single send() could silently truncate the
                        # request, and never encoded the str to bytes
                        sock.sendall(pckt_msg.encode())

                        # stream the file until the server closes the socket
                        self.__log.debug('Transferring file %s from %s', task_out_info, host_ip)
                        data = sock.recv(self.__file_transfer_buff_size)
                        while data:
                            fd.write(data)
                            data = sock.recv(self.__file_transfer_buff_size)

                        is_file_transfer_complete = True

                    except socket.error as e:
                        self.__log.error("Socket Error is %s", e)

                        # OSError carries errno as an attribute; the old
                        # e[0] subscription raised TypeError on Python 3
                        if e.errno == errno.EPIPE:
                            # remote peer disconnected
                            print("Detected remote disconnect")

                    except IOError as err_msg:
                        self.__log.error("Could not transfer file... retrying: %s IOERROR", err_msg)

                    except Exception as err_msg:
                        self.__log.error("Could not transfer file... retrying: %s", err_msg)

                    finally:
                        # always release the socket, even on the paths
                        # where the original leaked it (IOError branch)
                        if sock is not None:
                            sock.close()

            return True

        except Exception as err_msg:
            traceback.print_exc()
            self.__log.error("Error while transferring file : %s", err_msg)
            return False



def test():
    """Smoke test: fetch four map outputs for job 1 / reduce task 1."""
    nt = outputserver.DummyNodeTracker()
    outserver = outputserver.OutputServer(nt)
    # the constructor requires task_id and job_id — the original call
    # omitted them and raised TypeError before any fetching happened
    inpfetcher = InputFetcherNoQuery(outserver, nt, 1, 1)
    # NOTE(review): fetch_input reads index 3 (host IP) of each tuple
    # for non-local outputs; these 3-tuples assume the dummy tracker
    # serves everything locally — confirm against DummyNodeTracker
    out_of_tasks = [(1, 'M', 1), (1, 'M', 2), (1, 'M', 3), (1, 'M', 4)]
    inpfetcher.fetch_input(out_of_tasks, 1, 1)
    